// Copyright 2010-2025 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ortools/math_opt/elemental/codegen/gen_python.h"

#include <cstddef>
#include <cstdlib>
#include <memory>
#include <set>
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "ortools/math_opt/elemental/codegen/gen.h"

namespace operations_research::math_opt::codegen {

namespace {

// Returns the function-info table handed to the python code generator.
//
// The python output only contains enum definitions, so no per-operation
// functions are generated. The table is therefore just a value-initialized
// `AttrOpFunctionInfos`. It is allocated on the first call and never deleted.
const AttrOpFunctionInfos* GetPythonFunctionInfos() {
  static const AttrOpFunctionInfos* const kNoFunctionInfos =
      new AttrOpFunctionInfos();
  return kNoFunctionInfos;
}

// Emits one python enumerator per entry of `names` into `out`.
//
// Enumerators are numbered consecutively from 0 in the order of `names`.
// Each name is upper-cased and written on its own line, indented by two
// spaces, e.g. `  FOO = 0\n`.
void EmitEnumerators(const absl::Span<const absl::string_view> names,
                     std::string* out) {
  // Use an unsigned index so the bound check against `names.size()` does not
  // mix signed and unsigned operands.
  for (size_t i = 0; i < names.size(); ++i) {
    absl::StrAppendFormat(out, "  %s = %d\n", absl::AsciiStrToUpper(names[i]),
                          i);
  }
}

// Returns the python builtin type name used for values of the given attribute
// value type: `bool` for kBool, `int` for kInt64, and `float` for kDouble.
//
// Aborts if `value_type` is not one of the enumerators handled below.
absl::string_view GetAttrPyValueType(
    const CodegenAttrTypeDescriptor::ValueType& value_type) {
  switch (value_type) {
    case CodegenAttrTypeDescriptor::ValueType::kBool:
      return "bool";
    case CodegenAttrTypeDescriptor::ValueType::kInt64:
      return "int";
    case CodegenAttrTypeDescriptor::ValueType::kDouble:
      return "float";
  }
  // Without this, an unhandled value would fall off the end of a non-void
  // function (undefined behavior). Fail loudly instead of generating bogus
  // python.
  std::abort();
}

// Returns the numpy scalar type name used for values of the given attribute
// value type: `np.bool_` for kBool, `np.int64` for kInt64, and `np.float64`
// for kDouble. The names assume numpy is imported as `np` in the generated
// module.
//
// Aborts if `value_type` is not one of the enumerators handled below.
absl::string_view GetAttrNumpyValueType(
    const CodegenAttrTypeDescriptor::ValueType& value_type) {
  switch (value_type) {
    case CodegenAttrTypeDescriptor::ValueType::kBool:
      return "np.bool_";
    case CodegenAttrTypeDescriptor::ValueType::kInt64:
      return "np.int64";
    case CodegenAttrTypeDescriptor::ValueType::kDouble:
      return "np.float64";
  }
  // Without this, an unhandled value would fall off the end of a non-void
  // function (undefined behavior). Fail loudly instead of generating bogus
  // python.
  std::abort();
}

class PythonEnumsGenerator : public CodeGenerator {
 public:
  PythonEnumsGenerator() : CodeGenerator(GetPythonFunctionInfos()) {}

  void EmitHeader(std::string* out) const override {
    absl::StrAppend(out, R"(
'''DO NOT EDIT: This file is autogenerated.'''

import enum
from typing import Generic, TypeVar, Union

import numpy as np
)");
  }

  void EmitElements(absl::Span<const absl::string_view> elements,
                    std::string* out) const override {
    // Generate an enum for the elements.
    absl::StrAppend(out, "class ElementType(enum.Enum):\n");
    EmitEnumerators(elements, out);
    absl::StrAppend(out, "\n");
  }

  void EmitAttributes(absl::Span<const CodegenAttrTypeDescriptor> descriptors,
                      std::string* out) const override {
    absl::StrAppend(out, "\n");

    {
      // Collect the list of unique types:
      std::set<absl::string_view> value_types;
      for (const auto& descriptor : descriptors) {
        value_types.insert(GetAttrNumpyValueType(descriptor.value_type));
      }

      // Emit `AttrValueType`, a type variable for all attribute value types.
      absl::StrAppend(out, "AttrValueType = TypeVar('AttrValueType', ",
                      absl::StrJoin(value_types, ", "), ")\n");
    }
    absl::StrAppend(out, "\n");
    {
      std::set<absl::string_view> py_value_types;
      for (const auto& descriptor : descriptors) {
        py_value_types.insert(GetAttrPyValueType(descriptor.value_type));
      }
      absl::StrAppend(out, "AttrPyValueType = TypeVar('AttrPyValueType', ",
                      absl::StrJoin(py_value_types, ", "), ")\n");
    }

    // `Attr` is an attribute with any value type.
    absl::StrAppend(out, R"(
class Attr(Generic[AttrValueType]):
  pass
)");

    // `PyAttr` is an attribute with any value type.
    absl::StrAppend(out, R"(
class PyAttr(Generic[AttrPyValueType]):
  pass
)");

    // Generate an enum for the attribute type.
    for (const auto& descriptor : descriptors) {
      absl::StrAppendFormat(
          out, "\nclass %s(Attr[%s], PyAttr[%s], int, enum.Enum):\n",
          descriptor.name, GetAttrNumpyValueType(descriptor.value_type),
          GetAttrPyValueType(descriptor.value_type));
      EmitEnumerators(descriptor.attribute_names, out);
      absl::StrAppend(out, "\n");
    }

    // Add a type alias for the union of all attribute types.
    absl::StrAppend(
        out, "AnyAttr = Union[",
        absl::StrJoin(
            descriptors, ", ",
            [](std::string* out, const CodegenAttrTypeDescriptor& descriptor) {
              absl::StrAppend(out, descriptor.name);
            }),
        "]\n");
  }
};

}  // namespace

// Factory returning a freshly allocated generator for the python enums module.
// Ownership is transferred to the caller.
std::unique_ptr<CodeGenerator> PythonEnums() {
  std::unique_ptr<CodeGenerator> generator =
      std::make_unique<PythonEnumsGenerator>();
  return generator;
}

}  // namespace operations_research::math_opt::codegen
